今天是最後一天的主題,主要實現 API Gateway 的功能。當然,因為以學習為目標,
所以考慮的面向未必會比坊間現成的 Load Balancer、API Gateway 來得多。今天是最後一天了,
努力加油!!
api-gateway/
├── Cargo.toml
├── config.yaml
└── src/
    ├── main.rs
    ├── config.rs
    ├── proxy.rs
    ├── load_balancer.rs
    ├── health_check.rs
    ├── rate_limiter.rs
    └── circuit_breaker.rs
Cargo.toml
[package]
name = "api-gateway"
version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "1.35", features = ["full"] }
axum = "0.7"
hyper = { version = "1.0", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
tower = "0.4"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9"
tracing = "0.1"
tracing-subscriber = "0.3"
tokio-util = "0.7"
http-body-util = "0.1"
bytes = "1.5"
governor = "0.6"
parking_lot = "0.12"
# HTTP client used by src/health_check.rs to probe backends
# (without it `reqwest::Client` does not resolve and the crate fails to build).
reqwest = "0.12"

# Optional extensions shown at the end of the article:
# - TraceLayer needs:  tower-http = { version = "0.5", features = ["trace"] }
# - WebSocket extractors need axum's `ws` feature:
#   axum = { version = "0.7", features = ["ws"] }
src/config.rs
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level gateway configuration, deserialized from YAML by [`Config::from_file`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    /// Address the gateway listens on.
    pub server: ServerConfig,
    /// Proxied services keyed by service name (in `main` the name is only logged).
    pub services: HashMap<String, ServiceConfig>,
}
/// Listen address; `main` binds to `"{host}:{port}"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
    /// Interface to bind, e.g. `"0.0.0.0"`.
    pub host: String,
    /// TCP port to bind.
    pub port: u16,
}
/// Routing, balancing, health-check and rate-limit settings for one service.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServiceConfig {
    /// Path prefix under which requests are routed to this service.
    pub path_prefix: String,
    /// Upstream servers to balance across.
    pub backends: Vec<Backend>,
    /// Strategy used to pick a backend for each request.
    pub load_balancer: LoadBalancerType,
    /// Active health-check settings for `backends`.
    pub health_check: HealthCheckConfig,
    /// Optional rate limit; `None` means the proxy skips rate limiting.
    pub rate_limit: Option<RateLimitConfig>,
}
/// One upstream server.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Backend {
    /// Base URL such as `http://localhost:3001`; the proxy appends the
    /// incoming request's path and query to it.
    pub url: String,
    /// Relative weight for `weighted_round_robin`; treated as 1 when omitted.
    pub weight: Option<u32>,
}
/// Backend selection strategy. In YAML it is written in snake_case:
/// `round_robin`, `least_connections` or `weighted_round_robin`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum LoadBalancerType {
    /// Rotate through healthy backends in order.
    RoundRobin,
    /// Pick the healthy backend with the fewest tracked in-flight requests.
    LeastConnections,
    /// Rotate through healthy backends proportionally to `Backend::weight`.
    WeightedRoundRobin,
}
/// Active health-check settings for a service's backends.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HealthCheckConfig {
    /// Seconds between probe rounds.
    pub interval_secs: u64,
    /// Timeout, in seconds, applied to each probe request.
    pub timeout_secs: u64,
    /// Path appended to each backend URL to form the probe URL, e.g. `/health`.
    pub path: String,
}
/// Rate-limit settings; one un-keyed limiter is shared by all clients of the service.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RateLimitConfig {
    /// Allowed requests per second. Must be non-zero: `RateLimiter::new` panics on 0.
    pub requests_per_second: u32,
}
impl Config {
    /// Reads the YAML file at `path` and parses it into a [`Config`].
    ///
    /// # Errors
    /// Fails if the file cannot be read or its contents do not match the
    /// [`Config`] schema.
    pub fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let text = std::fs::read_to_string(path)?;
        serde_yaml::from_str(&text).map_err(|e| e.into())
    }
}
src/load_balancer.rs
use crate::config::{Backend, LoadBalancerType};
use parking_lot::RwLock;
use std::sync::Arc;
/// Picks a backend URL for each request according to the configured strategy.
///
/// Clones are cheap and share the same backend table and rotation counter,
/// so the proxy and the health checker operate on the same state.
#[derive(Clone)]
pub struct LoadBalancer {
    /// Backend table (health flag + in-flight count) shared by every clone.
    backends: Arc<RwLock<Vec<BackendState>>>,
    strategy: LoadBalancerType,
    /// Request counter driving (weighted) round-robin rotation.
    /// An atomic avoids a write lock per request, and `fetch_add` wraps on
    /// overflow instead of panicking like `+= 1` does in debug builds.
    current_index: Arc<std::sync::atomic::AtomicUsize>,
}

/// Runtime state tracked for one configured backend.
#[derive(Clone)]
struct BackendState {
    backend: Backend,
    healthy: bool,
    active_connections: usize,
}

impl LoadBalancer {
    /// Creates a balancer over `backends`, all initially marked healthy.
    pub fn new(backends: Vec<Backend>, strategy: LoadBalancerType) -> Self {
        let backend_states = backends
            .into_iter()
            .map(|backend| BackendState {
                backend,
                healthy: true,
                active_connections: 0,
            })
            .collect();
        Self {
            backends: Arc::new(RwLock::new(backend_states)),
            strategy,
            current_index: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        }
    }

    /// Returns the URL of the next backend to use, or `None` when no backend
    /// is currently marked healthy.
    pub fn next_backend(&self) -> Option<String> {
        let backends = self.backends.read();
        let healthy_backends: Vec<&BackendState> =
            backends.iter().filter(|b| b.healthy).collect();
        if healthy_backends.is_empty() {
            return None;
        }
        match self.strategy {
            LoadBalancerType::RoundRobin => self.round_robin(&healthy_backends),
            LoadBalancerType::LeastConnections => self.least_connections(&healthy_backends),
            LoadBalancerType::WeightedRoundRobin => self.weighted_round_robin(&healthy_backends),
        }
    }

    /// Advances the shared rotation counter and returns its previous value.
    /// `Relaxed` suffices: the counter only needs atomicity, it guards no other data.
    fn next_index(&self) -> usize {
        self.current_index
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }

    /// Effective weight of a backend (missing weight counts as 1).
    fn weight_of(state: &BackendState) -> u64 {
        u64::from(state.backend.weight.unwrap_or(1))
    }

    fn round_robin(&self, backends: &[&BackendState]) -> Option<String> {
        if backends.is_empty() {
            return None;
        }
        let backend = backends.get(self.next_index() % backends.len())?;
        Some(backend.backend.url.clone())
    }

    fn least_connections(&self, backends: &[&BackendState]) -> Option<String> {
        // Ties go to the earliest backend in configuration order.
        backends
            .iter()
            .min_by_key(|b| b.active_connections)
            .map(|b| b.backend.url.clone())
    }

    fn weighted_round_robin(&self, backends: &[&BackendState]) -> Option<String> {
        // Summed as u64 so many large u32 weights cannot overflow.
        let total_weight: u64 = backends.iter().map(|&b| Self::weight_of(b)).sum();
        if total_weight == 0 {
            // Every backend has weight 0: fall back to plain rotation.
            return self.round_robin(backends);
        }
        // usize -> u64 is lossless: usize is at most 64 bits on supported targets.
        let position = self.next_index() as u64 % total_weight;
        let mut cumulative = 0u64;
        for &backend in backends {
            cumulative += Self::weight_of(backend);
            if position < cumulative {
                return Some(backend.backend.url.clone());
            }
        }
        backends.first().map(|b| b.backend.url.clone())
    }

    /// Applies `update` to the backend whose URL equals `url`, if any.
    fn update_backend(&self, url: &str, update: impl FnOnce(&mut BackendState)) {
        let mut backends = self.backends.write();
        if let Some(state) = backends.iter_mut().find(|b| b.backend.url == url) {
            update(state);
        }
    }

    /// Records one more in-flight request to `url`.
    pub fn increment_connections(&self, url: &str) {
        self.update_backend(url, |b| b.active_connections += 1);
    }

    /// Records the end of an in-flight request to `url` (never below zero).
    pub fn decrement_connections(&self, url: &str) {
        self.update_backend(url, |b| {
            b.active_connections = b.active_connections.saturating_sub(1);
        });
    }

    /// Excludes `url` from selection until it is marked healthy again.
    pub fn mark_unhealthy(&self, url: &str) {
        self.update_backend(url, |b| b.healthy = false);
    }

    /// Makes `url` eligible for selection.
    pub fn mark_healthy(&self, url: &str) {
        self.update_backend(url, |b| b.healthy = true);
    }
}
src/health_check.rs
use crate::config::HealthCheckConfig;
use crate::load_balancer::LoadBalancer;
use std::time::Duration;
use tokio::time;
use tracing::{error, info};
pub struct HealthChecker {
load_balancer: LoadBalancer,
config: HealthCheckConfig,
backends: Vec<String>,
}
impl HealthChecker {
pub fn new(
load_balancer: LoadBalancer,
config: HealthCheckConfig,
backends: Vec<String>,
) -> Self {
Self {
load_balancer,
config,
backends,
}
}
pub async fn start(self) {
let mut interval = time::interval(Duration::from_secs(self.config.interval_secs));
loop {
interval.tick().await;
self.check_all_backends().await;
}
}
async fn check_all_backends(&self) {
for backend_url in &self.backends {
let health_url = format!("{}{}", backend_url, self.config.path);
match self.check_backend(&health_url).await {
Ok(true) => {
info!("Backend {} is healthy", backend_url);
self.load_balancer.mark_healthy(backend_url);
}
Ok(false) | Err(_) => {
error!("Backend {} is unhealthy", backend_url);
self.load_balancer.mark_unhealthy(backend_url);
}
}
}
}
async fn check_backend(&self, url: &str) -> Result<bool, Box<dyn std::error::Error>> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.config.timeout_secs))
.build()?;
let response = client.get(url).send().await?;
Ok(response.status().is_success())
}
}
src/rate_limiter.rs
use governor::{Quota, RateLimiter as GovernorRateLimiter};
use std::num::NonZeroU32;
use std::sync::Arc;
#[derive(Clone)]
pub struct RateLimiter {
limiter: Arc<GovernorRateLimiter<String, governor::state::direct::NotKeyed, governor::clock::DefaultClock>>,
}
impl RateLimiter {
pub fn new(requests_per_second: u32) -> Self {
let quota = Quota::per_second(NonZeroU32::new(requests_per_second).unwrap());
let limiter = Arc::new(GovernorRateLimiter::direct(quota));
Self { limiter }
}
pub fn check(&self) -> bool {
self.limiter.check().is_ok()
}
}
src/circuit_breaker.rs
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Failure-counting circuit breaker.
///
/// Opens after `failure_threshold` failures (counted since the last success)
/// and lets traffic through again once `timeout` has elapsed since the most
/// recent failure. Clones share the same state.
#[derive(Clone)]
pub struct CircuitBreaker {
    state: Arc<RwLock<CircuitState>>,
    failure_threshold: usize,
    timeout: Duration,
}

/// Mutable bookkeeping guarded by the breaker's lock.
struct CircuitState {
    /// Failures recorded since the last success.
    failures: usize,
    /// Time of the most recent failure, if any.
    last_failure: Option<Instant>,
    state: State,
}

/// Breaker position.
#[derive(PartialEq, Clone, Copy)]
enum State {
    /// Requests flow normally.
    Closed,
    /// Requests are rejected until `timeout` elapses after the last failure.
    Open,
    /// Cool-down over: requests are allowed; the next success closes the
    /// breaker and the next failure re-opens it.
    HalfOpen,
}

impl CircuitBreaker {
    /// Creates a closed breaker that opens after `failure_threshold` failures
    /// and stays open for `timeout_secs` seconds after the latest failure.
    pub fn new(failure_threshold: usize, timeout_secs: u64) -> Self {
        let initial = CircuitState {
            failures: 0,
            last_failure: None,
            state: State::Closed,
        };
        Self {
            state: Arc::new(RwLock::new(initial)),
            failure_threshold,
            timeout: Duration::from_secs(timeout_secs),
        }
    }

    /// Returns whether a request may proceed, moving `Open` to `HalfOpen`
    /// once the cool-down has elapsed.
    pub fn can_request(&self) -> bool {
        let mut guard = self.state.write();
        match guard.state {
            State::Closed | State::HalfOpen => true,
            State::Open => {
                let cooled_down = matches!(
                    guard.last_failure,
                    Some(at) if at.elapsed() > self.timeout
                );
                if cooled_down {
                    guard.state = State::HalfOpen;
                }
                cooled_down
            }
        }
    }

    /// Resets the failure count and closes the breaker.
    pub fn record_success(&self) {
        let mut guard = self.state.write();
        guard.failures = 0;
        guard.state = State::Closed;
    }

    /// Counts a failure and opens the breaker once the threshold is reached.
    pub fn record_failure(&self) {
        let mut guard = self.state.write();
        guard.failures += 1;
        guard.last_failure = Some(Instant::now());
        let tripped = guard.failures >= self.failure_threshold;
        if tripped {
            guard.state = State::Open;
        }
    }
}
src/proxy.rs
use crate::circuit_breaker::CircuitBreaker;
use crate::load_balancer::LoadBalancer;
use crate::rate_limiter::RateLimiter;
use axum::{
body::Body,
extract::State,
http::{Request, Response, StatusCode},
response::IntoResponse,
};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use std::sync::Arc;
use tracing::{error, info};
/// Per-service state shared (behind an `Arc`) by every request that
/// [`proxy_handler`] serves for that service.
#[derive(Clone)]
pub struct ProxyState {
    /// Chooses which backend receives each request.
    pub load_balancer: LoadBalancer,
    /// `None` when the service has no `rate_limit` configured.
    pub rate_limiter: Option<RateLimiter>,
    /// Sheds load after repeated upstream failures.
    pub circuit_breaker: CircuitBreaker,
    /// Pooled plain-HTTP client used to forward requests.
    pub client: Client<hyper_util::client::legacy::connect::HttpConnector, Body>,
}

impl ProxyState {
    /// Failures (since the last success) that open the circuit breaker.
    const BREAKER_FAILURE_THRESHOLD: usize = 5;
    /// Seconds after the last failure before an open breaker admits traffic again.
    const BREAKER_OPEN_SECS: u64 = 30;

    /// Builds the state with a fresh HTTP client and a closed circuit breaker.
    pub fn new(
        load_balancer: LoadBalancer,
        rate_limiter: Option<RateLimiter>,
    ) -> Self {
        let client = Client::builder(TokioExecutor::new()).build_http();
        Self {
            load_balancer,
            rate_limiter,
            circuit_breaker: CircuitBreaker::new(
                Self::BREAKER_FAILURE_THRESHOLD,
                Self::BREAKER_OPEN_SECS,
            ),
            client,
        }
    }
}
pub async fn proxy_handler(
State(state): State<Arc<ProxyState>>,
mut req: Request<Body>,
) -> impl IntoResponse {
// 檢查限流
if let Some(ref limiter) = state.rate_limiter {
if !limiter.check() {
return Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::from("Rate limit exceeded"))
.unwrap();
}
}
// 檢查熔斷器
if !state.circuit_breaker.can_request() {
return Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.body(Body::from("Service temporarily unavailable"))
.unwrap();
}
// 選擇後端服務
let backend_url = match state.load_balancer.next_backend() {
Some(url) => url,
None => {
return Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.body(Body::from("No healthy backends available"))
.unwrap();
}
};
// 修改請求 URI
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or(path);
let target_url = format!("{}{}", backend_url, path_query);
match target_url.parse::<hyper::Uri>() {
Ok(uri) => {
*req.uri_mut() = uri;
}
Err(e) => {
error!("Failed to parse URI: {}", e);
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Invalid backend URL"))
.unwrap();
}
}
// 增加連接計數
state.load_balancer.increment_connections(&backend_url);
// 發送請求
let response = match state.client.request(req).await {
Ok(resp) => {
info!("Request forwarded to {}", backend_url);
state.circuit_breaker.record_success();
resp
}
Err(e) => {
error!("Proxy error: {}", e);
state.circuit_breaker.record_failure();
state.load_balancer.decrement_connections(&backend_url);
return Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("Backend service error"))
.unwrap();
}
};
// 減少連接計數
state.load_balancer.decrement_connections(&backend_url);
response
}
mod circuit_breaker;
mod config;
mod health_check;
mod load_balancer;
mod proxy;
mod rate_limiter;
use axum::{routing::any, Router};
use config::Config;
use health_check::HealthChecker;
use load_balancer::LoadBalancer;
use proxy::{proxy_handler, ProxyState};
use rate_limiter::RateLimiter;
use std::sync::Arc;
use tracing::info;
use tracing_subscriber;
/// Gateway entry point.
///
/// Startup steps:
/// 1. Load `config.yaml`.
/// 2. Mount proxy routes for each configured service.
/// 3. Spawn one background health checker per service.
/// 4. Serve on `server.host:server.port`.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Initialise logging.
    tracing_subscriber::fmt::init();

    // Load configuration.
    let config = Config::from_file("config.yaml")?;
    let addr = format!("{}:{}", config.server.host, config.server.port);
    info!("Starting API Gateway on {}", addr);

    // Create the routes for each service.
    let mut app = Router::new();
    for (service_name, service_config) in config.services {
        info!("Configuring service: {}", service_name);

        // Load balancer; a clone is shared with the health checker below.
        let load_balancer = LoadBalancer::new(
            service_config.backends.clone(),
            service_config.load_balancer.clone(),
        );

        // Optional rate limiter.
        let rate_limiter = service_config
            .rate_limit
            .as_ref()
            .map(|rl| RateLimiter::new(rl.requests_per_second));

        // Proxy state.
        let proxy_state = Arc::new(ProxyState::new(load_balancer.clone(), rate_limiter));

        // Mount the service on both the bare prefix (e.g. `/api/users`) and
        // everything beneath it (`/api/users/*path`). The wildcard alone does
        // not match the bare prefix, which would otherwise return 404.
        // A trailing `/` in the config is trimmed so we never build `//*path`.
        let prefix = service_config.path_prefix.trim_end_matches('/');
        let exact_path = if prefix.is_empty() { "/" } else { prefix };
        let wildcard_path = format!("{}/*path", prefix);
        let handler = any(proxy_handler).with_state(proxy_state);
        app = app
            .route(exact_path, handler.clone())
            .route(&wildcard_path, handler);

        // Start health checking.
        let backend_urls: Vec<String> = service_config
            .backends
            .iter()
            .map(|b| b.url.clone())
            .collect();
        let health_checker = HealthChecker::new(
            load_balancer,
            service_config.health_check,
            backend_urls,
        );
        tokio::spawn(async move {
            health_checker.start().await;
        });
    }

    // Start the server.
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    info!("API Gateway listening on {}", addr);
    axum::serve(listener, app).await?;
    Ok(())
}
config.yaml
# Gateway configuration loaded by Config::from_file("config.yaml").
# Nesting matters: `host`/`port` belong under `server`, and each service's
# settings belong under its name in `services`.
server:
  host: "0.0.0.0"
  port: 8080

services:
  user_service:
    path_prefix: "/api/users"
    load_balancer: weighted_round_robin
    backends:
      - url: "http://localhost:3001"
        weight: 3
      - url: "http://localhost:3002"
        weight: 2
      - url: "http://localhost:3003"
        weight: 1
    health_check:
      interval_secs: 10
      timeout_secs: 3
      path: "/health"
    rate_limit:
      requests_per_second: 100

  order_service:
    path_prefix: "/api/orders"
    load_balancer: least_connections
    backends:
      # `weight` is optional (treated as 1); only weighted_round_robin uses it.
      - url: "http://localhost:4001"
      - url: "http://localhost:4002"
    health_check:
      interval_secs: 15
      timeout_secs: 5
      path: "/health"
    rate_limit:
      requests_per_second: 50

  product_service:
    path_prefix: "/api/products"
    load_balancer: round_robin
    backends:
      - url: "http://localhost:5001"
      - url: "http://localhost:5002"
      - url: "http://localhost:5003"
    health_check:
      interval_secs: 10
      timeout_secs: 3
      path: "/health"
    # No `rate_limit` section: rate limiting is disabled for this service.
簡單的測試後端(獨立的小專案,模擬 user_service 在 3001 port 的節點):
use axum::{routing::get, Router};
/// Minimal stand-in backend listening on port 3001.
///
/// Serves `/health` for the gateway's probes and any GET under
/// `/api/users/` with a fixed body.
#[tokio::main]
async fn main() {
    let health = get(|| async { "OK" });
    let users = get(|| async { "User Service Response" });
    let app = Router::new()
        .route("/health", health)
        .route("/api/users/*path", users);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await.unwrap();
    println!("Backend service running on :3001");
    axum::serve(listener, app).await.unwrap();
}
ouse axum::{
http::{Request, StatusCode},
middleware::Next,
response::Response,
};
pub async fn auth_middleware<B>(
req: Request<B>,
next: Next<B>,
) -> Result<Response, StatusCode> {
// 檢查 Authorization header
let auth_header = req
.headers()
.get("Authorization")
.and_then(|h| h.to_str().ok());
match auth_header {
Some(token) if token.starts_with("Bearer ") => {
// 驗證 token
Ok(next.run(req).await)
}
_ => Err(StatusCode::UNAUTHORIZED),
}
}
// Optional extensions (fragment, not a standalone file).
// `axum::middleware` is presumably meant for wiring `auth_middleware` via
// `middleware::from_fn(auth_middleware)`; nothing here uses it yet.
use axum::middleware;
// `TraceLayer` comes from the `tower-http` crate; add it (with its `trace`
// feature) to Cargo.toml to use this.
use tower_http::trace::TraceLayer;
// Add to the Router (e.g. in `main`, after the service routes are mounted):
app = app.layer(TraceLayer::new_for_http());
use axum::extract::ws::{WebSocket, WebSocketUpgrade};
async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<ProxyState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
/// Placeholder for WebSocket proxying: it currently does nothing with the
/// socket or the state and returns immediately.
async fn handle_socket(_socket: WebSocket, _state: Arc<ProxyState>) {
    // TODO: WebSocket proxying logic (e.g. pick a backend via
    // `_state.load_balancer` and relay frames in both directions).
}
經過這 30 天,我個人認為學到了很多東西。雖然工作繁忙,還要抽空寫這些,過程其實非常痛苦,
尤其第 15 到 25 天是最痛苦的階段,但也慢慢養成了每天發兩篇文章的習慣,
生理時鐘就像裝了雷達一樣,會提醒我這時候該寫文了。
這 30 天的旅程也是很有收穫,冥冥之中也做了不少事情。
我今年的鐵人賽沒有放額外的圖和講解,因為我發現如果寫得太困難會沒人看,
寫得太簡單我又會覺得無聊,但太簡單又會想補充很多東西,所以到後來我都放飛自我,
志在參加而非得獎,我認為今年應該也不會得什麼獎XD
如果我得獎我就把這段給刪掉XD
因為我寫的是最樸素的寫法,並沒有準備圖片或是宣傳語之類的,也完全按照個人意思寫文章,
所以我認為我高機率不會得獎,因為我覺得不會有人被我的文章所吸引。聽聞鐵人賽的讀者還是以初學者居多,
既然如此我就知道我不會得獎,選擇寫『進階技術』那類的文章即可。